from vlmeval.smp import *
from vlmeval.api.base import BaseAPI

# Base address of the StepFun HTTP API; the chat-completions endpoint below is
# the only route this module talks to.
_STEPFUN_API_BASE = "https://api.stepfun.com/v1"
url = f"{_STEPFUN_API_BASE}/chat/completions"

# Header template for requests sent to `url`. The "{}" in the Authorization
# value is a str.format placeholder that is filled in with the API key.
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer {}",
}


class StepAPI_INT(BaseAPI):
    """Client for the StepFun chat-completions endpoint.

    Converts a list of text/image message parts into a single user message,
    POSTs it to the module-level ``url`` and extracts the first choice's text.
    """

    is_api: bool = True

    def __init__(
        self,
        model: str = "step-1v-8k",
        retry: int = 10,
        wait: int = 3,
        key: str = None,
        temperature: float = 0,
        max_tokens: int = 300,
        verbose: bool = True,
        system_prompt: str = None,
        **kwargs,
    ):
        """Store request settings and resolve the API key.

        Args:
            model: Model name sent in the request payload.
            retry: Forwarded to ``BaseAPI.__init__``.
            wait: Forwarded to ``BaseAPI.__init__``.
            key: API key. When ``None``, the ``STEPAI_API_KEY`` environment
                variable is used, falling back to an empty string.
            temperature: Sampling temperature sent in the request payload.
            max_tokens: ``max_tokens`` value sent in the request payload.
            verbose: Forwarded to ``BaseAPI.__init__``.
            system_prompt: Stored on the instance and forwarded to
                ``BaseAPI.__init__``.
            **kwargs: Forwarded to ``BaseAPI.__init__``.
        """
        self.model = model
        self.fail_msg = "Fail to obtain answer via API."
        self.temperature = temperature
        self.max_tokens = max_tokens
        # NOTE(review): system_prompt is stored but never added to the
        # messages built by this class -- confirm whether that is intended.
        self.system_prompt = system_prompt
        if key is not None:
            self.key = key
        else:
            self.key = os.environ.get("STEPAI_API_KEY", "")

        # Build a per-instance copy of the header template. Formatting the
        # shared module-level dict in place would consume its "{}" placeholder
        # on the first instantiation, so every later instance would silently
        # reuse the first instance's key.
        self.headers = {
            **headers,
            "Authorization": headers["Authorization"].format(self.key),
        }

        super().__init__(
            retry=retry,
            wait=wait,
            verbose=verbose,
            system_prompt=system_prompt,
            **kwargs,
        )

    @staticmethod
    def build_msgs(msgs_raw):
        """Pack message parts into a one-element list holding a user message.

        Args:
            msgs_raw: Iterable of dicts with a ``"type"`` key (``"image"`` or
                ``"text"``) and a ``"value"`` key. For images the value is
                passed to ``encode_image_file_to_base64`` (presumably a file
                path -- verify against that helper); for text it is used
                verbatim. Parts of any other type are skipped.

        Returns:
            ``[{"role": "user", "content": [...]}]`` with one content entry per
            recognised part, in input order.
        """
        content = []

        for msg in msgs_raw:
            if msg["type"] == "image":
                image_b64 = encode_image_file_to_base64(msg["value"])
                # NOTE(review): the data URL always declares image/webp,
                # whatever format the helper actually produced -- confirm the
                # API tolerates a mismatching MIME type.
                content.append(
                    {
                        "image_url": {"url": "data:image/webp;base64,%s" % (image_b64)},
                        "type": "image_url",
                    }
                )
            elif msg["type"] == "text":
                content.append({"text": msg["value"], "type": "text"})

        return [{"role": "user", "content": content}]

    def generate_inner(self, inputs, **kwargs) -> tuple:
        """Send one chat-completions request and parse the reply.

        Args:
            inputs: Message parts in the format accepted by ``build_msgs``.
            **kwargs: Extra fields merged into the JSON request payload.

        Returns:
            A ``(ret_code, answer, response)`` tuple. ``ret_code`` is ``0`` for
            a 2xx HTTP status and the raw status code otherwise; ``answer`` is
            the stripped text of the first choice, or ``self.fail_msg`` when
            the response body cannot be parsed; ``response`` is the object
            returned by ``requests.post``.
        """
        print(inputs, "\n")
        payload = dict(
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            messages=self.build_msgs(msgs_raw=inputs),
            **kwargs,
        )
        # Use this instance's headers so the request carries this instance's
        # key rather than whatever the shared module-level dict holds.
        # NOTE(review): no timeout is passed, so a stalled connection can block
        # indefinitely -- consider wiring one in if the base class exposes it.
        response = requests.post(url, headers=self.headers, data=json.dumps(payload))
        ret_code = response.status_code
        ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code

        answer = self.fail_msg
        try:
            resp_struct = json.loads(response.text)
            answer = resp_struct["choices"][0]["message"]["content"].strip()
        except Exception as err:
            # Best effort: keep the failure message as the answer and let the
            # caller decide what to do based on ret_code.
            if self.verbose:
                self.logger.error(f"{type(err)}: {err}")
                self.logger.error(
                    response.text if hasattr(response, "text") else response
                )

        return ret_code, answer, response


class Step1V_INT(StepAPI_INT):
    """Variant of ``StepAPI_INT`` whose ``generate`` also accepts ``dataset``."""

    def generate(self, message, dataset=None):
        """Run ``generate`` on ``message``; ``dataset`` is accepted but not forwarded.

        Method lookup deliberately starts after ``StepAPI_INT`` in the MRO.
        """
        parent = super(StepAPI_INT, self)
        return parent.generate(message)
